import pickle
import yaml
import os
import math
import re
import numpy as np
import pandas as pd
import category_encoders as ce

from torch.utils.data import DataLoader, Dataset

import ipdb

def process_func(path: str, cat_list, missing_ratio=0.1, encode=True):
    """Read the CSV at ``path`` and build value / mask arrays for imputation.

    A fixed positional subset of 73 columns is kept, ``ed_los`` is converted
    from a timedelta string to minutes, and boolean masks marking the
    non-null cells are derived.  When ``encode`` is true the categorical
    columns are one-hot encoded with ``category_encoders`` and the masks are
    expanded so every generated indicator column inherits the mask of the
    column it came from.

    Args:
        path: CSV file to read.
        cat_list: Positional indices (within the 73 selected columns) of the
            categorical columns.  ``None`` falls back to ``[0]``, the value
            that used to be hard-coded (the original inline note labels that
            column 'gender').
        missing_ratio: Accepted for interface compatibility only.  No
            artificial masking is applied, so the value is not used.
        encode: Whether to one-hot encode the categorical columns.

    Returns:
        A tuple ``(observed_values, observed_masks, gt_masks, cont_list)``.
        Missing cells in ``observed_values`` are set to 0.  ``gt_masks`` is a
        separate copy of ``observed_masks``.  ``cont_list`` holds the column
        indices of the continuous features *in the returned arrays*, i.e. in
        the encoded layout when ``encode`` is true.

    Raises:
        ValueError: If an encoded column cannot be attributed to any input
            column.

    Side effects (``encode`` only): writes ``transformed_columns.pk`` and
    ``encoder.pk`` into ``./data_MIMIC4ED_onehot/``.

    NOTE(review): earlier revisions returned ``cont_list`` as
    ``range(1, n_cols)`` even after encoding; caches written by that code
    should be regenerated.
    """
    selected_columns = [
        7, 18, 26, 27, 28, 32, 33, 34, 35, 36, 37, 41, 42, 43, 44, 45, 46,
        47, 48, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
        65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 78, 79, 80, 82, 83,
        85, 89, 94, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108,
        109, 110, 111, 112, 113, 20, 115, 116,
    ]
    data = pd.read_csv(path)
    data = data.iloc[:, selected_columns]

    # Length of stay in minutes.
    # NOTE(review): `.dt.seconds` is only the seconds component within a day,
    # so whole days are dropped for stays of 24 h or more.  Left unchanged so
    # the feature is not silently altered; consider `.dt.total_seconds()`.
    data['ed_los'] = pd.to_timedelta(data['ed_los']).dt.seconds / 60

    cat_list = [0] if cat_list is None else list(cat_list)
    cat_set = set(cat_list)

    observed_masks = data.notna().values
    # No extra cells are hidden, so the ground-truth mask is a plain copy.
    gt_masks = observed_masks.copy()

    if not encode:
        cont_list = [i for i in range(data.shape[1]) if i not in cat_set]
        # Zero-fill missing cells.  Multiplying by the mask would leave NaN
        # in place because NaN * False is still NaN.
        observed_values = np.where(observed_masks, data.values, 0)
        return observed_values, observed_masks, gt_masks, cont_list

    cat_names = [data.columns[col] for col in cat_list]
    encoder = ce.one_hot.OneHotEncoder(cols=cat_names)
    encoder.fit(data)
    new_df = encoder.transform(data)

    # Attribute every encoded column to the input column it came from by
    # name, so nothing below depends on where the encoder places the
    # generated indicator columns.
    original_pos = {name: pos for pos, name in enumerate(data.columns)}
    prefixes = {col: f"{data.columns[col]}_" for col in cat_list}
    source_col = []
    cont_list = []
    saved_cat_dict = {str(col): [] for col in cat_list}
    for new_pos, name in enumerate(new_df.columns):
        if name in original_pos and original_pos[name] not in cat_set:
            source_col.append(original_pos[name])
            cont_list.append(new_pos)
            continue
        owners = [col for col in cat_list if str(name).startswith(prefixes[col])]
        if not owners:
            raise ValueError(
                f"Cannot map encoded column {name!r} back to an input column"
            )
        # Longest prefix wins when one categorical name is a prefix of another.
        owner = max(owners, key=lambda col: len(prefixes[col]))
        source_col.append(owner)
        saved_cat_dict[str(owner)].append(new_pos)

    # Each encoded column inherits the mask of its source column.
    new_observed_masks = observed_masks[:, source_col]
    new_gt_masks = gt_masks[:, source_col]

    # Cast before nan_to_num: nan_to_num leaves object arrays untouched.
    # `np.float` no longer exists in current NumPy, hence np.float64.
    new_observed_values = np.nan_to_num(new_df.values.astype(np.float64))

    # Indicator columns use -1 / 1 instead of 0 / 1, so that 0 is reserved
    # for "missing" once the mask is applied below.
    cat_cols = [pos for positions in saved_cat_dict.values() for pos in positions]
    if cat_cols:
        cat_block = new_observed_values[:, cat_cols]
        cat_block[cat_block == 0] = -1
        new_observed_values[:, cat_cols] = cat_block

    os.makedirs("./data_MIMIC4ED_onehot", exist_ok=True)
    with open("./data_MIMIC4ED_onehot/transformed_columns.pk", "wb") as f:
        pickle.dump([cont_list, saved_cat_dict], f)

    with open("./data_MIMIC4ED_onehot/encoder.pk", "wb") as f:
        pickle.dump(encoder, f)

    # Zero out every unobserved cell.
    new_observed_values = new_observed_values * new_observed_masks
    return new_observed_values, new_observed_masks, new_gt_masks, cont_list


class tabular_dataset(Dataset):
    """Torch ``Dataset`` over the cached value / mask arrays.

    Three cache states are handled on construction:

    * raw cache missing: the CSV is processed with ``process_func`` and the
      raw cache (including ``cont_cols``) is written;
    * raw and normalised caches present: the normalised arrays are loaded
      (``cont_cols`` is not stored there and is therefore not set);
    * raw cache present, normalised cache missing: the raw cache is loaded.
      Previously this state left the instance without any data and the
      constructor failed with ``AttributeError``.

    Args:
        eval_length: Length of the ``timepoints`` vector returned per item.
        use_index_list: Row indices exposed by this dataset; all rows when
            ``None``.
        missing_ratio: Only used to build the cache file names and forwarded
            to ``process_func``.
        seed: Seeds NumPy's global RNG and is part of the cache file names.
    """

    def __init__(self, eval_length=74, use_index_list=None, missing_ratio=0.1, seed=0):
        self.eval_length = eval_length
        np.random.seed(seed)

        dataset_path = "./data_MIMIC4ED_onehot/train.csv"
        processed_data_path = (
            f"./data_MIMIC4ED_onehot/missing_ratio-{missing_ratio}_seed-{seed}.pk"
        )
        processed_data_path_norm = f"./data_MIMIC4ED_onehot/missing_ratio-{missing_ratio}_seed-{seed}_max-min_norm.pk"

        # self.cont_cols is only saved in the .pk file written before normalization.
        # NOTE: the caches are pickles; only load files this code produced.
        cat_list = [0]
        if not os.path.isfile(processed_data_path):
            (
                self.observed_values,
                self.observed_masks,
                self.gt_masks,
                self.cont_cols,
            ) = process_func(
                dataset_path,
                cat_list=cat_list,
                missing_ratio=missing_ratio,
                encode=True,
            )

            with open(processed_data_path, "wb") as f:
                pickle.dump(
                    [
                        self.observed_values,
                        self.observed_masks,
                        self.gt_masks,
                        self.cont_cols,
                    ],
                    f,
                )
            print("--------Dataset created--------")

        elif os.path.isfile(processed_data_path_norm):  # load normalized cache
            with open(processed_data_path_norm, "rb") as f:
                self.observed_values, self.observed_masks, self.gt_masks = pickle.load(
                    f
                )
            print("--------Normalized dataset loaded--------")

        else:
            # Raw cache exists but normalization has not completed yet (for
            # example an earlier run stopped in between): fall back to the
            # raw arrays so the caller can normalize them.
            with open(processed_data_path, "rb") as f:
                (
                    self.observed_values,
                    self.observed_masks,
                    self.gt_masks,
                    self.cont_cols,
                ) = pickle.load(f)
            print("--------Dataset loaded--------")

        if use_index_list is None:
            self.use_index_list = np.arange(len(self.observed_values))
        else:
            self.use_index_list = use_index_list

    def __getitem__(self, org_index):
        """Return the sample dict for position ``org_index`` of the index list."""
        index = self.use_index_list[org_index]
        return {
            "observed_data": self.observed_values[index],
            "observed_mask": self.observed_masks[index],
            "gt_mask": self.gt_masks[index],
            "timepoints": np.arange(self.eval_length),
        }

    def __len__(self):
        """Number of rows exposed through ``use_index_list``."""
        return len(self.use_index_list)


def get_dataloader(seed=1, nfold=5, batch_size=16, missing_ratio=0.1):
    """Build the training ``DataLoader``, normalizing the data on first use.

    If the normalized cache for ``(missing_ratio, seed)`` does not exist yet,
    every column listed in ``dataset.cont_cols`` is rescaled to
    ``(x - (min - 1)) / (max - min + 1)`` using the observed training values,
    unobserved cells are reset to 0, and the result is pickled.  The dataset
    is then re-created so that it is served from the normalized cache.

    Args:
        seed: Seed for the dataset cache name; ``seed + 1`` seeds the shuffle
            of the row indices.
        nfold: Unused; kept for interface compatibility.
        batch_size: Batch size of the returned loader.
        missing_ratio: Forwarded to ``tabular_dataset``.

    Returns:
        ``(train_loader, valid_loader, test_loader)`` where the validation
        and test loaders are always ``None``: all rows are used for training.
    """
    dataset = tabular_dataset(missing_ratio=missing_ratio, seed=seed)
    print(f"Dataset size:{len(dataset)} entries")

    indlist = np.arange(len(dataset))

    np.random.seed(seed + 1)
    np.random.shuffle(indlist)

    # The complete shuffled index set is used for training.
    train_index = indlist

    # Here we perform max-min normalization.
    processed_data_path_norm = f"./data_MIMIC4ED_onehot/missing_ratio-{missing_ratio}_seed-{seed}_max-min_norm.pk"
    if not os.path.isfile(processed_data_path_norm):
        print(
            "--------------Dataset has not been normalized yet. Perform data normalization and store the mean value of each column.--------------"
        )
        col_num = len(dataset.cont_cols)
        max_arr = np.zeros(col_num)
        min_arr = np.zeros(col_num)
        for pos, col in enumerate(dataset.cont_cols):
            # Only observed cells count; missing values are stored as 0.
            observed = dataset.observed_masks[train_index, col].astype(bool)
            column = dataset.observed_values[train_index, col][observed]
            # A column with no observed value keeps min = max = 0 instead of
            # raising on an empty reduction.
            if column.size:
                max_arr[pos] = column.max()
                min_arr[pos] = column.min()

        print(
            f"--------------Max-value for cont-variable column {max_arr}--------------"
        )
        print(
            f"--------------Min-value for cont-variable column {min_arr}--------------"
        )

        for pos, col in enumerate(dataset.cont_cols):
            # The +/-1 shift maps observed values into (0, 1], so 0 stays
            # reserved for the cells masked out by observed_masks.
            shifted = dataset.observed_values[:, col] - (min_arr[pos] - 1)
            span = max_arr[pos] - min_arr[pos] + 1
            dataset.observed_values[:, col] = (
                shifted / span
            ) * dataset.observed_masks[:, col]

        with open(processed_data_path_norm, "wb") as f:
            pickle.dump(
                [dataset.observed_values, dataset.observed_masks, dataset.gt_masks], f
            )

    # Re-create the dataset so it is backed by the normalized cache.
    train_dataset = tabular_dataset(
        use_index_list=train_index, missing_ratio=missing_ratio, seed=seed
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    valid_loader = None

    test_loader = None

    print(f"Training dataset size: {len(train_dataset)}")

    return train_loader, valid_loader, test_loader
